{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "restricted-republic",
   "metadata": {},
   "source": [
    "# Assess predictions on multiclass wine data with a DNN model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adolescent-fusion",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to use the `responsibleai` API to assess a PyTorch DNN model trained on the multiclass wine dataset. It walks through the API calls needed to create a widget with model analysis insights, then guides a visual analysis of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "exempt-cartoon",
   "metadata": {},
   "source": [
    "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n",
    "    * [Train a DNN Model](#Train-a-DNN-Model)\n",
    "    * [Create Model and Data Insights](#Create-Model-and-Data-Insights)\n",
    "* [Assess Your Model](#Assess-Your-Model)\n",
    "    * [Aggregate Analysis](#Aggregate-Analysis)\n",
    "    * [Individual Analysis](#Individual-Analysis)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "continent-dream",
   "metadata": {},
   "source": [
    "## Launch Responsible AI Toolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "welsh-crisis",
   "metadata": {},
   "source": [
    "This section walks through the code that creates the datasets and the model, then uses the `responsibleai` API to generate insights that can be analyzed visually."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sophisticated-bryan",
   "metadata": {},
   "source": [
    "### Train a DNN Model\n",
    "*The following section can be skipped. It loads a dataset and trains a model for illustrative purposes.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "indie-message",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Fix the random seed so that weight initialization and dropout are reproducible\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c960272",
   "metadata": {},
   "source": [
    "#### Load the wine data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c82acbfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "wine = load_wine()\n",
    "X = wine['data']\n",
    "y = wine['target']\n",
    "classes = wine['target_names']\n",
    "feature_names = wine['feature_names']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "412b58b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out half of the samples as the test set; random_state fixes the split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e815edc6",
   "metadata": {},
   "source": [
    "#### Define a simple PyTorch classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd3ff4ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pytorch_net(numCols, numClasses=3):\n",
    "    class Net(nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.norm = nn.LayerNorm(numCols)\n",
    "            self.fc1 = nn.Linear(numCols, 100)\n",
    "            # Despite its name, fc2 is a dropout layer rather than a linear layer\n",
    "            self.fc2 = nn.Dropout(p=0.2)\n",
    "            self.fc3 = nn.Linear(100, numClasses)\n",
    "            # Normalize over the class dimension; relying on an implicit dim is deprecated\n",
    "            self.output = nn.Softmax(dim=-1)\n",
    "\n",
    "        def forward(self, X):\n",
    "            X = self.norm(X)\n",
    "            X = F.relu(self.fc1(X))\n",
    "            X = self.fc2(X)\n",
    "            X = self.fc3(X)\n",
    "            # Return class probabilities, one row per sample\n",
    "            return self.output(X)\n",
    "    return Net()\n",
    "\n",
    "# Convert the NumPy training arrays to tensors (float features, integer labels)\n",
    "torch_X = torch.Tensor(X_train).float()\n",
    "torch_y = torch.Tensor(y_train).long()\n",
    "\n",
    "# Create network structure\n",
    "net = pytorch_net(X_train.shape[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "potential-proportion",
   "metadata": {},
   "source": [
    "#### Train the PyTorch DNN classifier on the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "431e414b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model with full-batch gradient descent\n",
    "epochs = 10000\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    optimizer.zero_grad()\n",
    "    out = net(torch_X)\n",
    "    loss = criterion(out, torch_y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # Log the loss every 1000 epochs (and at the end) instead of on every step\n",
    "    if epoch % 1000 == 0 or epoch == epochs - 1:\n",
    "        print('epoch: ', epoch, ' loss: ', loss.item())"
   ]
  },
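  {
   "cell_type": "markdown",
   "id": "softmax-loss-sketch-note",
   "metadata": {},
   "source": [
    "A note on reading the loss values: `nn.CrossEntropyLoss` applies a log-softmax to its input internally, while the network above already ends in a softmax layer, so softmax is effectively applied twice during training. The softmax output is presumably kept so that the wrapped model can return class probabilities. The cell below is a minimal sketch of the effect that uses only the standard library; the helper names `softmax` and `cross_entropy_from_logits` are illustrative and not part of PyTorch. It suggests that, for three classes, a loss computed on probabilities rather than raw scores cannot drop below `log(1 + 2/e)`, roughly 0.55."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "softmax-loss-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def softmax(z):\n",
    "    # Numerically stable softmax over a list of floats\n",
    "    m = max(z)\n",
    "    exps = [math.exp(v - m) for v in z]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def cross_entropy_from_logits(z, target):\n",
    "    # Single-sample cross-entropy on raw scores: -log(softmax(z)[target])\n",
    "    return -math.log(softmax(z)[target])\n",
    "\n",
    "logits = [4.0, 0.0, 0.0]  # illustrative raw scores that favor class 0\n",
    "probs = softmax(logits)   # what a network ending in softmax would output\n",
    "\n",
    "loss_on_logits = cross_entropy_from_logits(logits, 0)\n",
    "# Feeding probabilities into the same loss applies softmax a second time\n",
    "loss_on_probs = cross_entropy_from_logits(probs, 0)\n",
    "# Smallest value the loss can reach when its input is a 3-class probability vector\n",
    "loss_floor = math.log(1 + 2 / math.e)\n",
    "\n",
    "print('loss on raw scores:    ', round(loss_on_logits, 4))\n",
    "print('loss on probabilities: ', round(loss_on_probs, 4))\n",
    "print('floor on probabilities:', round(loss_floor, 4))"
   ]
  },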
  {
   "cell_type": "markdown",
   "id": "88ff4385",
   "metadata": {},
   "source": [
    "Wrap the model with scikit-learn style `predict`/`predict_proba` functions using the `wrap_model` function from [ml-wrappers](https://github.com/microsoft/ml-wrappers) to make it compatible with `RAIInsights` and the `ResponsibleAIDashboard`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaa360d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_wrappers import wrap_model, DatasetWrapper\n",
    "model = wrap_model(net, DatasetWrapper(X_train), model_task='classification')"
   ]
  },
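  {
   "cell_type": "markdown",
   "id": "wrapper-contract-sketch-note",
   "metadata": {},
   "source": [
    "To illustrate what a scikit-learn style interface means here, the cell below sketches a hypothetical wrapper, `ProbabilityModelWrapper`, around any function that returns class probabilities. The class, the toy scoring function and the sample values are assumptions made for illustration; this is not how `wrap_model` is implemented. It only shows the contract: `predict_proba` returns one row of class probabilities per sample, and `predict` returns the index of the most probable class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wrapper-contract-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class ProbabilityModelWrapper:\n",
    "    # Hypothetical wrapper exposing predict / predict_proba around a callable\n",
    "    def __init__(self, proba_fn):\n",
    "        self._proba_fn = proba_fn\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        # One row per sample, one column per class\n",
    "        return np.asarray(self._proba_fn(np.asarray(X, dtype=float)))\n",
    "\n",
    "    def predict(self, X):\n",
    "        # Predicted label is the index of the most probable class\n",
    "        return np.argmax(self.predict_proba(X), axis=1)\n",
    "\n",
    "def toy_proba_fn(X):\n",
    "    # Toy stand-in for a trained network: softmax over the first three columns\n",
    "    z = X[:, :3]\n",
    "    z = z - z.max(axis=1, keepdims=True)\n",
    "    e = np.exp(z)\n",
    "    return e / e.sum(axis=1, keepdims=True)\n",
    "\n",
    "wrapped = ProbabilityModelWrapper(toy_proba_fn)\n",
    "sample = np.array([[2.0, 0.5, 0.1], [0.0, 3.0, 1.0]])\n",
    "proba = wrapped.predict_proba(sample)\n",
    "labels = wrapped.predict(sample)\n",
    "print(proba.shape, labels)"
   ]
  },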
  {
   "cell_type": "markdown",
   "id": "continued-praise",
   "metadata": {},
   "source": [
    "### Create Model and Data Insights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "residential-identification",
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiwidgets import ResponsibleAIDashboard\n",
    "from responsibleai import RAIInsights"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cheap-juice",
   "metadata": {},
   "source": [
    "To use Responsible AI Toolbox, initialize a `RAIInsights` object upon which different components can be loaded.\n",
    "\n",
    "`RAIInsights` accepts the model, the training dataset, the test dataset, the target feature string and the task type string as its arguments.\n",
    "\n",
    "You may also create a `FeatureMetadata` container to identify any feature of your choice as the `identity_feature`, specify a list of categorical feature names via the `categorical_features` parameter, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` object can then be passed to `RAIInsights`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bulgarian-hepatitis",
   "metadata": {},
   "outputs": [],
   "source": [
    "from responsibleai.feature_metadata import FeatureMetadata\n",
    "feature_metadata = FeatureMetadata(categorical_features=[], dropped_features=[])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bulgarian-hepatitis-2",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_feature = 'wine'\n",
    "# Build DataFrames that hold the target column alongside the features\n",
    "X_train = pd.DataFrame(X_train, columns=feature_names)\n",
    "X_test = pd.DataFrame(X_test, columns=feature_names)\n",
    "X_train[target_feature] = y_train\n",
    "X_test[target_feature] = y_test\n",
    "\n",
    "# Arguments: model, training data, test data, target column name, task type\n",
    "rai_insights = RAIInsights(model, X_train, X_test, target_feature, 'classification',\n",
    "                           feature_metadata=feature_metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "original-rolling",
   "metadata": {},
   "source": [
    "Add the components of the toolbox that are focused on model assessment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "governing-antique",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpretability\n",
    "rai_insights.explainer.add()\n",
    "# Error Analysis\n",
    "rai_insights.error_analysis.add()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "unexpected-bicycle",
   "metadata": {},
   "source": [
    "Once all the desired components have been loaded, compute insights on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "average-calibration",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "elder-fleet",
   "metadata": {},
   "source": [
    "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "thousand-louis",
   "metadata": {},
   "outputs": [],
   "source": [
    "ResponsibleAIDashboard(rai_insights)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
